Wavelet tree
A wavelet tree is a static data structure on an integer sequence that answers a rich family of range queries ($k$-th smallest, count of elements less than a value, range frequency, range median, and more), each in $O(\log \sigma)$ time, where $\sigma$ is the size of the value domain. Building the tree takes $O(n \log \sigma)$ time and space. In competitive programming $\sigma$ is almost always reduced to $O(n)$ by coordinate compression, so the bounds become $O(\log n)$ per query and $O(n \log n)$ for construction.
The wavelet tree is closely related to the persistent segment tree, which answers the same queries with the same asymptotic complexity. In practice the wavelet tree is faster (better cache behaviour, and it stores only prefix counts instead of $O(n \log n)$ versioned nodes with child pointers) and significantly simpler to implement correctly.
Description
Given an array $A[1..n]$ of integers drawn from a value range of size $\sigma$, a wavelet tree is a balanced binary tree over the value range. Each node covers a value sub-range $[lo, hi]$. At every internal node ($lo < hi$), with $mid = lo + \lfloor (hi - lo)/2 \rfloor$:
- Elements with value $\le mid$ are routed to the left child
- Elements with value $> mid$ are routed to the right child
The node stores a prefix-count array $cnt$, where $cnt[i]$ is the number of the first $i$ elements (in the node's local ordering) that were routed left, with $cnt[0] = 0$. With these counts every query can navigate the tree without storing the values explicitly.
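To make the definition concrete, here is a minimal sketch that computes one node's prefix counts with the same split rule as the construction code. The helper name `rootCounts` and the example array are illustrative choices, not part of the structure itself.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: prefix counts for a single node covering values [lo, hi].
// cnt[i] = how many of the first i elements satisfy value <= mid (i.e. go left).
std::vector<int> rootCounts(const std::vector<int>& a, int lo, int hi) {
    int mid = lo + (hi - lo) / 2;
    std::vector<int> cnt(a.size() + 1, 0);
    for (int i = 0; i < (int)a.size(); ++i)
        cnt[i + 1] = cnt[i] + (a[i] <= mid ? 1 : 0);
    return cnt;
}
```

For $A = [3, 1, 4, 1, 5, 9, 2, 6]$ over $[1, 9]$ we get $mid = 5$; only 9 and 6 go right, so the root's counts are $0, 1, 2, 3, 4, 5, 5, 6, 6$.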
Construction
#include <algorithm>
#include <vector>
using namespace std;

struct WaveletTree {
int lo, hi;
WaveletTree *left = nullptr, *right = nullptr;
vector<int> cnt; // cnt[i] = # of first i elements routed to left child
// Build over A[from, to) with value range [lo, hi].
// The array is reordered in place during construction.
void build(int *from, int *to, int lo, int hi) {
this->lo = lo; this->hi = hi;
if (from >= to || lo == hi) return;
int mid = lo + (hi - lo) / 2;
cnt.reserve(to - from + 1);
cnt.push_back(0);
for (auto it = from; it != to; ++it)
cnt.push_back(cnt.back() + (*it <= mid));
auto pivot = stable_partition(from, to,
[mid](int x){ return x <= mid; });
left = new WaveletTree(); left->build(from, pivot, lo, mid);
right = new WaveletTree(); right->build(pivot, to, mid+1, hi);
}
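stable_partition physically splits the array so that left-routed elements
come first; the left child then builds over that prefix, the right child over
the remainder. Stability matters: it keeps the left-routed (and the right-routed)
elements in their original relative order, which is what lets the prefix counts
map a position in the parent to a position in a child.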
Queries
All queries take a 1-indexed range $[l, r]$. At each node, two values summarize the range: $lb = cnt[l-1]$ and $rb = cnt[r]$.
$rb - lb$ elements of $[l, r]$ went left; the rest went right. In the left child the range maps to $[lb + 1, rb]$; in the right child it maps to $[l - lb, r - rb]$.
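A minimal sketch of this index mapping, assuming a node's $cnt$ array with $cnt[0] = 0$: the left child receives $[lb + 1, rb]$ and the right child $[l - lb, r - rb]$, where $lb = cnt[l-1]$ and $rb = cnt[r]$. The function name `childRanges` and its pair-of-pairs return type are illustrative assumptions.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Given a node's prefix-count array cnt (cnt[0] = 0) and a 1-indexed range [l, r],
// return the corresponding 1-indexed ranges in the left and right children.
// An empty range comes out as {x, x - 1}.
std::pair<std::pair<int, int>, std::pair<int, int>>
childRanges(const std::vector<int>& cnt, int l, int r) {
    int lb = cnt[l - 1], rb = cnt[r];
    std::pair<int, int> leftRange{lb + 1, rb};       // the rb - lb left-routed elements
    std::pair<int, int> rightRange{l - lb, r - rb};  // the remaining elements
    return {leftRange, rightRange};
}
```

For $A = [3, 1, 4, 1, 5, 9, 2, 6]$ (root counts $0, 1, 2, 3, 4, 5, 5, 6, 6$) and the range $[5, 8]$ holding 5, 9, 2, 6, this gives $[5, 6]$ in the left child (values 5, 2) and $[1, 2]$ in the right child (values 9, 6).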
$k$-th smallest (kth(l, r, k)): At each node count how many elements in $[l, r]$
went left. If $k$ does not exceed that count, recurse left; otherwise
subtract the count from $k$ and recurse right. At a leaf ($lo = hi$) the answer is $lo$.
// k-th smallest element in A[l..r] (1-indexed, k is 1-based)
int kth(int l, int r, int k) {
if (lo == hi) return lo;
int lb = cnt[l-1], rb = cnt[r];
int inLeft = rb - lb;
if (k <= inLeft) return left->kth(lb+1, rb, k);
return right->kth(l-lb, r-rb, k-inLeft);
}
Count less than $v$ (countLess(l, r, v)): returns $|\{ i \in [l, r] : A[i] < v \}|$.
If $v \le lo$ or $v > hi$ the answer is trivial: 0 in the first case, $r - l + 1$ in the second.
Otherwise, if $v \le mid$ the right subtree can contribute nothing
(all its values are $> mid \ge v$), so we recurse left only. If $v > mid$,
all left elements are $\le mid < v$, so they all count, plus a
recursive call into the right subtree.
// # elements in A[l..r] with value < v
int countLess(int l, int r, int v) {
if (l > r || v <= lo) return 0;  // empty range (e.g. an empty sub-tree) or nothing below v
if (v > hi) return r - l + 1;
int mid = lo + (hi - lo) / 2;
int lb = cnt[l-1], rb = cnt[r];
if (v <= mid) return left->countLess(lb+1, rb, v);
return (rb - lb) + right->countLess(l-lb, r-rb, v);
}
};
Range frequency of value $v$ in $[l, r]$: countLess(l, r, v+1) - countLess(l, r, v)
Range median: kth(l, r, (r - l + 2) / 2), the lower median when $r - l + 1$ is even
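For experimenting, the pieces above can be collected into one compilation unit. The sketch below repeats the struct under the hypothetical name `WT`, adds the two derived queries as member functions (`freq` and `median` are names chosen here, not from the original), and guards `countLess` against empty ranges so it stays safe on sub-trees that received no elements (possible when values are not compressed).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Pointer-based wavelet tree from this article, collected into one unit.
struct WT {
    int lo = 0, hi = 0;
    WT *left = nullptr, *right = nullptr;
    vector<int> cnt;  // cnt[i] = # of the first i elements routed to the left child

    // Build over [from, to) with value range [lo_, hi_]; reorders the input in place.
    void build(int *from, int *to, int lo_, int hi_) {
        lo = lo_; hi = hi_;
        if (from >= to || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        cnt.reserve(to - from + 1);
        cnt.push_back(0);
        for (int *it = from; it != to; ++it) cnt.push_back(cnt.back() + (*it <= mid));
        int *pivot = stable_partition(from, to, [mid](int x) { return x <= mid; });
        left = new WT(); left->build(from, pivot, lo, mid);
        right = new WT(); right->build(pivot, to, mid + 1, hi);
    }
    // k-th smallest in [l, r], 1-indexed, 1 <= k <= r - l + 1.
    int kth(int l, int r, int k) {
        if (lo == hi) return lo;
        int lb = cnt[l - 1], rb = cnt[r], inLeft = rb - lb;
        if (k <= inLeft) return left->kth(lb + 1, rb, k);
        return right->kth(l - lb, r - rb, k - inLeft);
    }
    // Number of values < v in [l, r].
    int countLess(int l, int r, int v) {
        if (l > r || v <= lo) return 0;
        if (v > hi) return r - l + 1;
        int mid = lo + (hi - lo) / 2, lb = cnt[l - 1], rb = cnt[r];
        if (v <= mid) return left->countLess(lb + 1, rb, v);
        return (rb - lb) + right->countLess(l - lb, r - rb, v);
    }
    int freq(int l, int r, int v) { return countLess(l, r, v + 1) - countLess(l, r, v); }
    int median(int l, int r) { return kth(l, r, (r - l + 2) / 2); }  // lower median
};
```

Because `build` reorders its input, pass it a copy if the original order is still needed.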
Complexity
Each query descends at most one root-to-leaf path of the value-range tree, which has height $\lceil \log_2 \sigma \rceil$. Each step is $O(1)$. Hence every query is $O(\log \sigma)$.
Building involves one stable_partition pass at every level of the tree.
Each element participates in at most one pass per level, giving $O(n \log \sigma)$ total time. The $cnt$ arrays across all nodes at any single
level together hold at most $2n$ integers ($n$ counts plus one leading zero per non-empty node), so total space is also $O(n \log \sigma)$.
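The space bound can be written out as a sum over depths. Here $V_d$ (the non-empty internal nodes at depth $d$) and $n_v$ (the number of elements routed to node $v$) are notation introduced only for this derivation:

$$
\sum_{d=0}^{\lceil \log_2 \sigma \rceil - 1} \; \sum_{v \in V_d} (n_v + 1)
\;\le\; \sum_{d=0}^{\lceil \log_2 \sigma \rceil - 1} (n + n)
\;=\; 2n \lceil \log_2 \sigma \rceil \;=\; O(n \log \sigma),
$$

using $\sum_{v \in V_d} n_v \le n$ (each element reaches at most one node per depth) and $|V_d| \le n$ (each non-empty node holds at least one element). The time bound follows the same way, since node $v$ does $O(n_v)$ work.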
Applications
- Range order statistics: finding the $k$-th smallest element in a
subarray; counting how many elements fall in a value range $[a, b]$ (as
countLess(l, r, b+1) - countLess(l, r, a)). See Order statistics tree for the single-element analogue.
- Sliding window median: for a window of fixed size $k$ the median is
kth(i, i+k-1, (k+1)/2), answered in $O(\log n)$ per window position, the same bound as the $O(\log n)$ per insertion of a balanced BST approach but typically with a better constant.
- Range count-distinct: number of distinct values in $[l, r]$.
Define $prev[i]$ = last position before $i$ with the same value (0
if none). Build the wavelet tree on the $prev$ array. Then the
count of distinct values in $[l, r]$ equals the count of positions $i \in [l, r]$ with $prev[i] < l$, which is exactly
countLess(l, r, l) on the wavelet tree, answered in $O(\log n)$.
- Predecessor / successor in a range: the largest value $\le v$ in $[l, r]$
follows from combining countLess with kth: compute $k$ = countLess(l, r, v+1); if $k > 0$ the answer is kth(l, r, k), otherwise no such element exists. Symmetrically, with $k$ = countLess(l, r, v), the smallest value $\ge v$ is kth(l, r, k+1) when $k < r - l + 1$.
- Bitvector rank/select: at each tree level the $cnt$ array is a rank structure over a bitvector (one bit per element: routed left or right); the wavelet tree generalises this to multi-valued alphabets, which underpins many suffix array and compressed index applications.
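The $prev$-array reduction for count-distinct can be sketched as follows. To keep the block short and independent of any particular tree, `countDistinct` uses a linear scan as a stand-in for countLess(l, r, l); the helper names are hypothetical. Since $prev[i] \in [0, n-1]$, the real wavelet tree would be built on $prev[1..n]$ over a value range such as $[0, n-1]$.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// prev[i] (1-indexed) = last j < i with a[j] == a[i], or 0 if none. prev[0] is unused.
std::vector<int> previousOccurrence(const std::vector<int>& a) {
    std::vector<int> prev(a.size() + 1, 0);
    std::unordered_map<int, int> last;  // value -> last 1-indexed position seen
    for (int i = 1; i <= (int)a.size(); ++i) {
        auto it = last.find(a[i - 1]);
        prev[i] = (it == last.end()) ? 0 : it->second;
        last[a[i - 1]] = i;
    }
    return prev;
}

// Distinct values in a[l..r] = # of i in [l, r] with prev[i] < l
// (only the first occurrence of each value inside the range qualifies).
// Linear-scan stand-in; a wavelet tree on prev answers this as countLess(l, r, l).
int countDistinct(const std::vector<int>& prev, int l, int r) {
    int c = 0;
    for (int i = l; i <= r; ++i) c += prev[i] < l;
    return c;
}
```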
Variants
Coordinate compression
When values can be much larger than $n$, map them to $[1, \sigma]$, where $\sigma \le n$ is the number of distinct values, before building the tree:
// Returns a wavelet tree over A[0..n-1], compressed to [1, n].
// sorted_vals receives the sorted unique values so kth results can be decoded.
WaveletTree* buildCompressed(vector<int> &A, vector<int> &sorted_vals) {
sorted_vals = A;
sort(sorted_vals.begin(), sorted_vals.end());
sorted_vals.erase(unique(sorted_vals.begin(), sorted_vals.end()),
sorted_vals.end());
int sigma = sorted_vals.size();
vector<int> C = A;
for (int &x : C)
x = lower_bound(sorted_vals.begin(), sorted_vals.end(), x)
- sorted_vals.begin() + 1;
WaveletTree *wt = new WaveletTree();
wt->build(C.data(), C.data() + C.size(), 1, sigma);
return wt;
}
// Decode: original value = sorted_vals[ wt->kth(l, r, k) - 1 ]
// countLess for an original threshold x (counts original values < x):
//   int threshold = lower_bound(sorted_vals.begin(), sorted_vals.end(), x)
//                   - sorted_vals.begin() + 1;
//   wt->countLess(l, r, threshold);
Array-based (flat) implementation
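The pointer-based tree above allocates a node per split. A flat layout stores one $cnt$ array per level instead, which improves cache performance and removes allocation overhead. The sketch below splits by the bits of $A[i] - lo$ (most significant first) rather than by $mid$, and performs one global stable partition per level; strictly speaking this arrangement is the wavelet matrix variant, which supports the same queries with the same bounds but navigates differently (a range that moves to the 1-side is shifted by that level's total count of zeros):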
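struct WaveletFlat {
int n, lo, hi, levels;
vector<vector<int>> cnt; // cnt[d][i] = # of the first i elements at level d whose bit is 0
WaveletFlat(vector<int> A, int lo, int hi) : n(A.size()), lo(lo), hi(hi) {
levels = 1;
while ((1 << levels) < hi - lo + 1) levels++;
cnt.assign(levels, vector<int>(n + 1, 0));
vector<int> nxt(n);
for (int d = 0; d < levels; d++) {
int bit = levels - 1 - d; // bits of (A[i] - lo), MSB first
for (int i = 0; i < n; i++)
cnt[d][i+1] = cnt[d][i] + !((A[i] - lo) >> bit & 1);
// One global stable partition by `bit`: 0s go to [0, Z), 1s to [Z, n),
// where Z = cnt[d][n] (the wavelet-matrix arrangement).
int li = 0, ri = cnt[d][n];
for (int i = 0; i < n; i++) {
if (!((A[i] - lo) >> bit & 1)) nxt[li++] = A[i];
else nxt[ri++] = A[i];
}
swap(A, nxt);
}
}
// k-th smallest in A[l..r] (1-indexed, 1 <= k <= r - l + 1).
int kth(int l, int r, int k) const {
int a = l - 1, b = r, res = 0; // current range as half-open [a, b), 0-indexed
for (int d = 0; d < levels; d++) {
int bit = levels - 1 - d;
int za = cnt[d][a], zb = cnt[d][b], Z = cnt[d][n];
if (k <= zb - za) { a = za; b = zb; } // answer has a 0 at this bit
else { // answer has a 1 at this bit
k -= zb - za; res |= 1 << bit;
a = Z + (a - za); b = Z + (b - zb);
}
}
return lo + res;
}
};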
The flat implementation avoids new entirely and is typically 2–3× faster in
practice.
Range sum augmentation
To answer "sum of the $k$ smallest elements in $[l, r]$", augment each
node with a parallel prefix-sum array alongside cnt, accumulating the
values of elements that go left. A sumKSmallest(l, r, k) query mirrors
kth: when the $k$ elements fit in the left subtree, return the left subtree's
answer; otherwise add the full left sum over the range and recurse right for the remaining $k - inLeft$ elements.
At a leaf every remaining element equals $lo$, so the leaf contributes $k \cdot lo$. Together with the left sums collected on the way down, this
yields the answer without changing the $O(\log \sigma)$ per-query bound.
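A compact sketch of the sum augmentation, assuming the values are used directly without coordinate compression (with compression, the `sum` arrays would accumulate the original values and the leaf would multiply by the decoded value). The struct name `WTSum` is hypothetical; it repeats the construction shown earlier with one extra prefix array per node.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Sum-augmented wavelet tree: alongside cnt, sum[i] holds the total value of
// the first i elements of this node that were routed to the left child.
struct WTSum {
    int lo = 0, hi = 0;
    WTSum *left = nullptr, *right = nullptr;
    vector<int> cnt;
    vector<long long> sum;

    void build(int *from, int *to, int lo_, int hi_) {
        lo = lo_; hi = hi_;
        if (from >= to || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        cnt.assign(1, 0);
        sum.assign(1, 0);
        for (int *it = from; it != to; ++it) {
            bool goesLeft = *it <= mid;
            cnt.push_back(cnt.back() + goesLeft);
            sum.push_back(sum.back() + (goesLeft ? *it : 0));
        }
        int *pivot = stable_partition(from, to, [mid](int x) { return x <= mid; });
        left = new WTSum(); left->build(from, pivot, lo, mid);
        right = new WTSum(); right->build(pivot, to, mid + 1, hi);
    }
    // Sum of the k smallest values in [l, r] (1-indexed, 0 <= k <= r - l + 1).
    long long sumKSmallest(int l, int r, int k) {
        if (k == 0) return 0;
        if (lo == hi) return (long long)k * lo;  // every element here equals lo
        int lb = cnt[l - 1], rb = cnt[r], inLeft = rb - lb;
        if (k <= inLeft) return left->sumKSmallest(lb + 1, rb, k);
        // all left-routed elements of the range are among the k smallest
        return (sum[r] - sum[l - 1]) + right->sumKSmallest(l - lb, r - rb, k - inLeft);
    }
};
```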
Problems
$k$-th order statistics
Solution sketch — K-th Number (MKTHNUM)
This is the canonical wavelet tree problem: given $a_1, \dots, a_n$ and $q$ queries $(l, r, k)$, output the $k$-th smallest value in $a_l, \dots, a_r$.
Coordinate-compress the values to $[1, \sigma]$, build the wavelet tree, and for
each query call kth(l, r, k) in $O(\log n)$. Total time $O((n + q) \log n)$.
Solution sketch — Sliding Median (CSES 1076)
For a sliding window of width $k$, the median position is $p = \lceil k/2 \rceil$ (the smaller middle element when $k$ is even).
After building the wavelet tree on the full array, each window
contributes one query kth(i, i+k-1, p). Each query is $O(\log n)$, giving
$O(n \log n)$ total, the same as a balanced BST sliding window but typically with a
smaller constant.
Range counting
Solution sketch — K Query (KQUERY)
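Each query asks: how many elements in $a_i, \dots, a_j$ are greater than $k$?
This is $(j - i + 1) - |\{ t \in [i, j] : a_t \le k \}|$. After
coordinate-compressing, countLess(i, j, threshold) counts the elements $\le k$, where threshold is
one more than the number of distinct values $\le k$ (the upper_bound rank of $k$ in the sorted values, plus 1); it runs in $O(\log n)$ per query.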
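The threshold computation is easy to get off by one. A minimal sketch, assuming the compression scheme of buildCompressed (value $x$ maps to its index in `sorted_vals` plus 1); the helper name `thresholdAtMost` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// sorted_vals: sorted distinct original values; x is compressed to (index of x) + 1.
// Returns the compressed threshold t such that, for every element y,
//   compressed(y) < t  <=>  y <= k.
int thresholdAtMost(const std::vector<int>& sorted_vals, int k) {
    return int(std::upper_bound(sorted_vals.begin(), sorted_vals.end(), k)
               - sorted_vals.begin()) + 1;
}
```

A KQUERY answer would then be `(j - i + 1) - wt->countLess(i, j, thresholdAtMost(sorted_vals, k))`.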
Count-distinct and range partitioning
Solution sketch — Till I Collapse (CF 786C)
For each $k = 1, \dots, n$, compute $ans_k$: the number of groups produced by greedily splitting the array into maximal-length contiguous segments each with at most $k$ distinct values.
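Build a wavelet tree on the previous-occurrence array $prev$,
where $prev[i]$ is the last index $j < i$ with $a_j = a_i$ (or 0 if none). Then the number of distinct values in $[l, r]$
is countLess(l, r, l) on this tree, a single $O(\log n)$ call.
For a fixed $k$, simulate the greedy: starting at $l = 1$, binary search for
the rightmost $r$ such that countDistinct(l, r) $\le k$, advance $l \gets r + 1$,
and increment the group counter. With a plain binary search each greedy step costs $O(\log^2 n)$, and
there are at most $\lceil n/k \rceil$ steps for a given $k$, so by the harmonic sum there are $O(n \log n)$ steps over all $k$ and an overall $O(n \log^3 n)$
bound. The work can also be organized offline in a suitable order, or with the standard
divide-and-conquer on $k$, which exploits that $ans_k$ is non-increasing in $k$.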
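The greedy for a fixed $k$ can be sketched as below. The distinct-count query is passed in as a callback (in a full solution it would be countLess(l, r, l) on the wavelet tree built over $prev$); the names `groupsFor` and `countDistinct` are hypothetical. The binary search relies on countDistinct(l, r) being non-decreasing in $r$.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Number of greedy groups with at most k distinct values each, over positions 1..n.
// countDistinct(l, r) must return the number of distinct values in a[l..r] (1-indexed).
int groupsFor(int n, int k, const std::function<int(int, int)>& countDistinct) {
    int groups = 0;
    for (int l = 1; l <= n; ) {
        // Rightmost r in [l, n] with countDistinct(l, r) <= k; r = l always qualifies.
        int lo = l, hi = n;
        while (lo < hi) {
            int mid = lo + (hi - lo + 1) / 2;  // round up so the loop always shrinks
            if (countDistinct(l, mid) <= k) lo = mid; else hi = mid - 1;
        }
        ++groups;
        l = lo + 1;  // next group starts right after this maximal segment
    }
    return groups;
}
```

For $a = [1, 3, 4, 3, 3]$ this greedy yields 4, 2, 1, 1, 1 groups for $k = 1, \dots, 5$.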
See also
- Persistent segment tree — same asymptotic complexity; wavelet tree is the offline, space-efficient alternative
- Merge sort tree: stores a sorted list at each segment-tree node; answers countLess-style queries in $O(\log^2 n)$ (and $k$-th smallest in $O(\log^3 n)$ by binary searching the answer) but is simpler to extend to updates
- Segment tree — the underlying range-decomposition idea
- Order statistics tree — handles point insertions/deletions but lacks arbitrary range queries
- Coordinate compression: essential preprocessing step when the values are much larger than $n$

